import tensorflow as tf 
import numpy as np 

# Hyperparameters shared by every DDPG instance in this module.

# Learning rate of the Adam optimizer that trains the actor network.
LR_A = 0.001
# Learning rate of the Adam optimizer that trains the critic network.
LR_C = 0.001
# Discount factor applied to the target critic's value in the TD target.
GAMMA = 0.9
# Mixing weight of the soft target update: target <- (1 - TAU) * target + TAU * eval.
TAU = 0.01
# Number of transition rows allocated for the replay memory.
MEMORY_CAPACITY = 10000
# Number of transitions sampled from the replay memory per learning step.
BATCH_SIZE = 32 


class DDPG(object):
	"""Actor-critic agent (DDPG) built on the TensorFlow 1.x graph API.

	The agent owns:
	  * a fixed-size replay memory whose rows are laid out as
	    ``[s, a, r, s_]``,
	  * an actor and a critic, each with an 'eval' (trained) and a
	    'target' (slowly tracking) copy,
	  * a ``tf.Session`` in which all of the above is evaluated.
	"""

	def __init__(self, a_dim, s_dim, a_bound):
		"""Build the replay memory, the computation graph and the session.

		Args:
			a_dim: number of action components.
			s_dim: number of state components.
			a_bound: indexable action bounds; element ``[1]`` is used as the
				scale applied to the actor's tanh output.
		"""
		# One row per transition: s (s_dim) | a (a_dim) | r (1) | s_ (s_dim).
		self.memory = np.zeros((MEMORY_CAPACITY, s_dim * 2 + a_dim + 1), dtype=np.float32)
		self.pointer = 0
		self.memory_full = False
		self.sess = tf.Session()
		self.a_replace_counter, self.c_replace_counter = 0, 0

		self.a_dim, self.s_dim, self.a_bound = a_dim, s_dim, a_bound[1]
		self.S = tf.placeholder(tf.float32, [None, s_dim], 's')
		self.S_ = tf.placeholder(tf.float32, [None, s_dim], 's_')
		self.R = tf.placeholder(tf.float32, [None, 1], 'r')

		with tf.variable_scope('Actor'):
			self.a = self._build_a(self.S, scope='eval', trainable=True)
			a_ = self._build_a(self.S_, scope='target', trainable=False)

		with tf.variable_scope('Critic'):
			# The eval critic is wired to the actor output so the actor loss can
			# back-propagate through it; `learn` overrides `self.a` via feed_dict
			# when training the critic on stored actions.
			q = self._build_c(self.S, self.a, scope='eval', trainable=True)
			q_ = self._build_c(self.S_, a_, scope='target', trainable=False)

		# Collected before the optimizers are created, so these lists contain
		# only the network weights (no optimizer slot variables).
		self.ae_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/eval')
		self.at_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/target')
		self.ce_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/eval')
		self.ct_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/target')

		# Pair each network's target/eval variables separately. Zipping the
		# actor and critic lists together would truncate to the shorter list
		# and leave the trailing critic variables without a soft update.
		param_pairs = list(zip(self.at_params, self.ae_params))
		param_pairs += list(zip(self.ct_params, self.ce_params))
		self.soft_replace = [
			tf.assign(target, (1 - TAU) * target + TAU * online)
			for target, online in param_pairs
		]

		q_target = self.R + GAMMA * q_
		td_error = tf.losses.mean_squared_error(labels=q_target, predictions=q)
		self.ctrain = tf.train.AdamOptimizer(LR_C).minimize(td_error, var_list=self.ce_params)

		a_loss = -tf.reduce_mean(q)  # maximize q
		self.atrain = tf.train.AdamOptimizer(LR_A).minimize(a_loss, var_list=self.ae_params)

		self.sess.run(tf.global_variables_initializer())

	def choose_action(self, s):
		"""Return the eval actor's action for a single state vector `s`."""
		state_batch = np.asarray(s)[np.newaxis, :]
		return self.sess.run(self.a, feed_dict={self.S: state_batch})[0]

	def learn(self):
		"""Run one soft target update followed by one actor and one critic step.

		The batch is drawn (with replacement) only from rows that have actually
		been written, so a partially filled memory never contributes its
		all-zero placeholder rows. Does nothing if no transition is stored.
		"""
		n_stored = min(self.pointer, self.memory.shape[0])
		if n_stored == 0:
			return

		self.sess.run(self.soft_replace)

		indices = np.random.choice(n_stored, size=BATCH_SIZE)
		batch = self.memory[indices, :]
		bs = batch[:, :self.s_dim]
		ba = batch[:, self.s_dim:self.s_dim + self.a_dim]
		# Slice (rather than index) the reward column to keep shape (batch, 1).
		br = batch[:, -self.s_dim - 1:-self.s_dim]
		bs_ = batch[:, -self.s_dim:]

		self.sess.run(self.atrain, feed_dict={self.S: bs})
		self.sess.run(self.ctrain, feed_dict={self.S: bs, self.a: ba, self.R: br, self.S_: bs_})

	def store_transition(self, s, a, r, s_):
		"""Write one transition into the replay memory, overwriting the oldest.

		Args:
			s: state before the action.
			a: action taken.
			r: scalar reward.
			s_: state after the action.
		"""
		transition = np.hstack((s, a, [r], s_))
		capacity = self.memory.shape[0]
		index = self.pointer % capacity
		# Index a single row; a slice such as `memory[index:]` would broadcast
		# the transition over every row from `index` to the end of the buffer.
		self.memory[index, :] = transition
		self.pointer += 1
		if self.pointer >= capacity:
			self.memory_full = True

	def _build_a(self, s, scope, trainable):
		"""Build an actor network mapping states `s` to scaled actions."""
		with tf.variable_scope(scope):
			hidden = tf.layers.dense(s, 100, activation=tf.nn.relu, name='l1', trainable=trainable)
			a = tf.layers.dense(hidden, self.a_dim, activation=tf.nn.tanh, name='a', trainable=trainable)
			# tanh output lies in [-1, 1]; scale it by the action bound.
			return tf.multiply(a, self.a_bound, name='scaled_a')

	def _build_c(self, s, a, scope, trainable):
		"""Build a critic network mapping (state, action) to a single value."""
		with tf.variable_scope(scope):
			n_l1 = 100
			w1_s = tf.get_variable('w1_s', [self.s_dim, n_l1], trainable=trainable)
			w1_a = tf.get_variable('w1_a', [self.a_dim, n_l1], trainable=trainable)
			b1 = tf.get_variable('b1', [1, n_l1], trainable=trainable)
			hidden = tf.nn.relu(tf.matmul(s, w1_s) + tf.matmul(a, w1_a) + b1)
			return tf.layers.dense(hidden, 1, trainable=trainable)

	def save(self):
		"""Save the session's variables under the './params/' prefix."""
		# Saver.save does not create the target directory itself; MakeDirs is
		# a no-op when the directory already exists.
		tf.gfile.MakeDirs('./params/')
		saver = tf.train.Saver()
		saver.save(self.sess, './params/', write_meta_graph=False)

	def restore(self):
		"""Restore the session's variables from the './params/' prefix."""
		saver = tf.train.Saver()
		saver.restore(self.sess, './params/')